[mlir][affine] Fix the issue of ceildiv-mul-ceildiv expressions not satisfying commutativity #109382
@llvm/pr-subscribers-llvm-adt @llvm/pr-subscribers-mlir-affine
Author: long.chen (lipracer)
Changes
Fixes #107508
Full diff: https://github.com/llvm/llvm-project/pull/109382.diff
3 Files Affected:
diff --git a/llvm/include/llvm/ADT/ScopeExit.h b/llvm/include/llvm/ADT/ScopeExit.h
index 2f13fb65d34d80..7e126479df3a14 100644
--- a/llvm/include/llvm/ADT/ScopeExit.h
+++ b/llvm/include/llvm/ADT/ScopeExit.h
@@ -31,13 +31,14 @@ template <typename Callable> class scope_exit {
template <typename Fp>
explicit scope_exit(Fp &&F) : ExitFunction(std::forward<Fp>(F)) {}
- scope_exit(scope_exit &&Rhs)
- : ExitFunction(std::move(Rhs.ExitFunction)), Engaged(Rhs.Engaged) {
- Rhs.release();
- }
+ scope_exit(scope_exit &&Rhs) { *this = std::move(Rhs); }
scope_exit(const scope_exit &) = delete;
- scope_exit &operator=(scope_exit &&) = delete;
scope_exit &operator=(const scope_exit &) = delete;
+ scope_exit &operator=(scope_exit &&Rhs) {
+ Engaged = std::exchange(Rhs.Engaged, false);
+ ExitFunction = std::exchange(Rhs.ExitFunction, {});
+ return *this;
+ }
void release() { Engaged = false; }
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index fc7ede279643ed..84af1f11045d6d 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -9,6 +9,8 @@
#include <cmath>
#include <cstdint>
#include <limits>
+#include <numeric>
+#include <optional>
#include <utility>
#include "AffineExprDetail.h"
@@ -18,9 +20,8 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/MathExtras.h"
-#include <numeric>
-#include <optional>
using namespace mlir;
using namespace mlir::detail;
@@ -362,54 +363,119 @@ static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
opKind == AffineExprKind::CeilDiv) &&
"unexpected opKind");
- switch (expr.getKind()) {
- case AffineExprKind::Constant:
- return cast<AffineConstantExpr>(expr).getValue() == 0;
- case AffineExprKind::DimId:
- return false;
- case AffineExprKind::SymbolId:
- return (cast<AffineSymbolExpr>(expr).getPosition() == symbolPos);
- // Checks divisibility by the given symbol for both operands.
- case AffineExprKind::Add: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
- }
- // Checks divisibility by the given symbol for both operands. Consider the
- // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
- // this is a division by s1 and both the operands of modulo are divisible by
- // s1 but it is not divisible by s1 always. The third argument is
- // `AffineExprKind::Mod` for this reason.
- case AffineExprKind::Mod: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
- AffineExprKind::Mod) &&
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
- AffineExprKind::Mod);
- }
- // Checks if any of the operand divisible by the given symbol.
- case AffineExprKind::Mul: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
- isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
- }
- // Floordiv and ceildiv are divisible by the given symbol when the first
- // operand is divisible, and the affine expression kind of the argument expr
- // is same as the argument `opKind`. This can be inferred from commutative
- // property of floordiv and ceildiv operations and are as follow:
- // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
- // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
- // It will fail if operations are not same. For example:
- // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
- case AffineExprKind::FloorDiv:
- case AffineExprKind::CeilDiv: {
- AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
- if (opKind != expr.getKind())
- return false;
- return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
- }
+ std::vector<std::tuple<AffineExpr, unsigned, AffineExprKind,
+ llvm::detail::scope_exit<std::function<void(void)>>>>
+ stack;
+ stack.emplace_back(expr, symbolPos, opKind, []() {});
+ bool result = false;
+
+ while (!stack.empty()) {
+ AffineExpr expr = std::get<0>(stack.back());
+ unsigned symbolPos = std::get<1>(stack.back());
+ AffineExprKind opKind = std::get<2>(stack.back());
+
+ switch (expr.getKind()) {
+ case AffineExprKind::Constant: {
+ // Note: Assignment must occur before pop, which will affect whether it
+ // enters other execution branches.
+ result = cast<AffineConstantExpr>(expr).getValue() == 0;
+ stack.pop_back();
+ break;
+ }
+ case AffineExprKind::DimId: {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+ case AffineExprKind::SymbolId: {
+ result = cast<AffineSymbolExpr>(expr).getPosition() == symbolPos;
+ stack.pop_back();
+ break;
+ }
+ // Checks divisibility by the given symbol for both operands.
+ case AffineExprKind::Add: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, opKind,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Checks divisibility by the given symbol for both operands. Consider the
+ // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv
+ // s1`, this is a division by s1 and both the operands of modulo are
+ // divisible by s1 but it is not divisible by s1 always. The third argument
+ // is `AffineExprKind::Mod` for this reason.
+ case AffineExprKind::Mod: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos,
+ AffineExprKind::Mod,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Checks if any of the operand divisible by the given symbol.
+ case AffineExprKind::Mul: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, opKind,
+ [&stack, &result, binaryExpr, symbolPos, opKind]() {
+ if (!result) {
+ stack.emplace_back(
+ binaryExpr.getRHS(), symbolPos, opKind,
+ [&stack]() { stack.pop_back(); });
+ } else {
+ stack.pop_back();
+ }
+ });
+ break;
+ }
+ // Floordiv and ceildiv are divisible by the given symbol when the first
+ // operand is divisible, and the affine expression kind of the argument expr
+ // is same as the argument `opKind`. This can be inferred from commutative
+ // property of floordiv and ceildiv operations and are as follow:
+ // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
+ // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
+ // It will fail in two cases: 1. if the operations are not the same, for
+ // example `(exp1 ceildiv exp2) floordiv exp3` can not be simplified; 2. if
+ // there is a multiplication in the expression, for example
+ // `((exp1 ceildiv exp2) mul exp3) ceildiv exp4` can not be simplified.
+ case AffineExprKind::FloorDiv:
+ case AffineExprKind::CeilDiv: {
+ AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
+ if (opKind != expr.getKind()) {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+ if (llvm::any_of(stack, [](auto &it) {
+ return std::get<0>(it).getKind() == AffineExprKind::Mul;
+ })) {
+ result = false;
+ stack.pop_back();
+ break;
+ }
+
+ stack.emplace_back(binaryExpr.getLHS(), symbolPos, expr.getKind(),
+ [&stack]() { stack.pop_back(); });
+ break;
+ }
+ llvm_unreachable("Unknown AffineExpr");
+ }
}
- llvm_unreachable("Unknown AffineExpr");
+ return result;
}
/// Divides the given expression by the given symbol at position `symbolPos`. It
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 92d3d86bc93068..d1f34f20fa5dad 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -308,10 +308,26 @@ func.func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
}
// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
-// CHECK-LABEL: func @semiaffine_composite_floor
-func.func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+// CHECK-LABEL: func @semiaffine_composite_ceildiv
+func.func @semiaffine_composite_ceildiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+ // CHECK: %[[CST:.*]] = arith.constant 43
+ return %a : index
+}
+
+// Tests that a semi-affine expression with a nested ceildiv-mul-ceildiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_ceildiv_mul_ceildiv
+func.func @semiaffine_composite_ceildiv_mul_ceildiv(%arg0: index, %arg1: index) -> index {
%a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
- // CHECK: %[[CST:.*]] = arith.constant 47
+ // CHECK-NOT: arith.constant
+ return %a : index
+}
+
+// Tests that a semi-affine expression with a nested floordiv-mul-floordiv operation is not simplified.
+// CHECK-LABEL: func @semiaffine_composite_floordiv_mul_floordiv
+func.func @semiaffine_composite_floordiv_mul_floordiv(%arg0: index, %arg1: index) -> index {
+ %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) floordiv 4) * 5) + s0 * 42) floordiv s0)> (%arg0)[%arg1]
+ // CHECK-NOT: arith.constant
return %a : index
}
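The updated test expectations can be checked numerically. The following is only a sketch: `ceilDiv`, `composite`, and `compositeMul` are hypothetical helper names (not part of MLIR), and it assumes `s0 > 0`. The expression without the multiplication evaluates to 43 for every positive `s0`, while the one with `* 5` gives 47 for `s0 = 1` but 45 for `s0 = 2`, so folding it to the constant 47 is not valid.

```cpp
#include <cstdint>

// Ceiling division, assuming a non-negative numerator and a positive
// denominator (hypothetical helper, not the MLIR implementation).
int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// (((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0
int64_t composite(int64_t s0) {
  return ceilDiv(ceilDiv(s0 * 2, 4) + s0 * 42, s0);
}

// ((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0
int64_t compositeMul(int64_t s0) {
  return ceilDiv(ceilDiv(s0 * 2, 4) * 5 + s0 * 42, s0);
}
```

Calling `composite` over a range of positive values and comparing against 43, then comparing `compositeMul(1)` with `compositeMul(2)`, reproduces the difference between the two tests.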
Hey, I see you're still actively changing this. Let me know when it's ready for review.
Ready for review now.
✅ With the latest revision this PR passed the C/C++ code formatter.
return static_cast<ImplType *>(expr)->position;
}

/// A manually managed stack used to convert recursive function calls into
Why not just use AffineExpr::walk? It'll do a post-order traversal, so you could just have a stack of partial results. Unless I'm still missing something :-).
This requires a pre-order traversal. Perhaps we can refactor AffineExprVisitor to support pre-order traversal.
Can you give me a small example/explanation where pre-order traversal is needed?
Sorry, I didn't express myself clearly. The implementation of isDivisibleBySymbol requires a pre-order traversal.
e.g.:
case add: visit(lhs) and visit(rhs)
case mul: visit(lhs) or visit(rhs)
The traversal is controlled by the kind of the current expression. Perhaps having the visitor return both an interrupt signal and a boolean result could meet this requirement, but the visitor would also need to be reimplemented.
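The control flow described in this comment can be sketched on a toy expression tree. Everything here (`Kind`, `Expr`, `leaf`, `binary`, `isDivisible`) is a hypothetical stand-in for illustration, not the MLIR `AffineExpr` API, and the div/mod cases are left out. The parent decides, from its own kind, whether the right operand is visited at all.

```cpp
#include <memory>
#include <utility>

// Hypothetical toy expression tree; not the MLIR AffineExpr classes.
enum class Kind { Constant, Dim, Symbol, Add, Mul };

struct Expr {
  Kind kind;
  long value;  // constant value, or dim/symbol position
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(Kind kind, long value = 0) {
  return std::make_shared<Expr>(Expr{kind, value, nullptr, nullptr});
}
ExprPtr binary(Kind kind, ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{kind, 0, std::move(lhs), std::move(rhs)});
}

// The parent is inspected first and controls the descent:
// Add needs both operands divisible, Mul needs only one.
bool isDivisible(const ExprPtr &e, long symbolPos) {
  switch (e->kind) {
  case Kind::Constant:
    return e->value == 0;
  case Kind::Dim:
    return false;
  case Kind::Symbol:
    return e->value == symbolPos;
  case Kind::Add:
    // `&&` skips the rhs when the lhs is already not divisible.
    return isDivisible(e->lhs, symbolPos) && isDivisible(e->rhs, symbolPos);
  case Kind::Mul:
    // `||` skips the rhs when the lhs is already divisible.
    return isDivisible(e->lhs, symbolPos) || isDivisible(e->rhs, symbolPos);
  }
  return false;
}
```

For example, `s0 * 42` is reported divisible by `s0`, while `s0 * 42 + d0` is not.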
In a post-order traversal with a stack for intermediates, you'd have the results for lhs and rhs on the top of the stack when you visit an add, right? Can you give me a more complete example where this doesn't work?
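A minimal sketch of the post-order variant suggested here, again on a hypothetical toy tree rather than the real `AffineExpr::walk` (`walkPostOrder` and `isDivisibleViaWalk` are illustrative names, and the div/mod cases are omitted). Leaves push their result; an add or mul pops the two operand results that are on top of the stack and pushes the combined one.

```cpp
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy expression tree; not the MLIR AffineExpr classes.
enum class Kind { Constant, Dim, Symbol, Add, Mul };

struct Expr {
  Kind kind;
  long value;  // constant value, or dim/symbol position
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr leaf(Kind kind, long value = 0) {
  return std::make_shared<Expr>(Expr{kind, value, nullptr, nullptr});
}
ExprPtr binary(Kind kind, ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{kind, 0, std::move(lhs), std::move(rhs)});
}

// Post-order walk: both children are visited before the node itself.
void walkPostOrder(const ExprPtr &e,
                   const std::function<void(const Expr &)> &fn) {
  if (!e)
    return;
  walkPostOrder(e->lhs, fn);
  walkPostOrder(e->rhs, fn);
  fn(*e);
}

bool isDivisibleViaWalk(const ExprPtr &root, long symbolPos) {
  std::vector<bool> results;  // stack of partial results
  walkPostOrder(root, [&](const Expr &e) {
    switch (e.kind) {
    case Kind::Constant:
      results.push_back(e.value == 0);
      break;
    case Kind::Dim:
      results.push_back(false);
      break;
    case Kind::Symbol:
      results.push_back(e.value == symbolPos);
      break;
    case Kind::Add:
    case Kind::Mul: {
      // The operand results are on top of the stack, rhs above lhs.
      bool rhs = results.back();
      results.pop_back();
      bool lhs = results.back();
      results.pop_back();
      results.push_back(e.kind == Kind::Add ? (lhs && rhs) : (lhs || rhs));
      break;
    }
    }
  });
  return results.back();
}
```

Note that this form always visits both operands (no short-circuit), and a callback invoked at the node does not by itself pass information such as a changed `opKind` down to the children.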
I see no problem with using an explicit stack in addition to the implicit one in ::walk.
If you're concerned about stack overflows, the correct thing to do is to refactor ::walk to not be recursive, since that function is used in many places, so fixing this particular location isn't really sufficient if stack overflows are an issue with ::walk.
Refactoring MLIR walk() to not be recursive is a long-standing issue: patches welcome to fix this!
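As a rough illustration of what a non-recursive post-order walk could look like, here is a sketch over a hypothetical binary `Node` type. It is not the MLIR `walk` implementation or a proposed patch; `Node` and `walkPostOrderIterative` are made-up names. Each stack frame records whether the node's children have already been scheduled.

```cpp
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical binary tree node used only for this sketch.
struct Node {
  int id;
  std::shared_ptr<Node> lhs, rhs;
};

// Post-order walk driven by an explicit stack instead of recursion.
void walkPostOrderIterative(const std::shared_ptr<Node> &root,
                            const std::function<void(const Node &)> &fn) {
  // Each frame: the node and whether its children were already pushed.
  std::vector<std::pair<const Node *, bool>> stack;
  if (root)
    stack.emplace_back(root.get(), false);

  while (!stack.empty()) {
    auto [node, expanded] = stack.back();  // copy, safe across pushes
    if (expanded) {
      // Both subtrees are done, so the node itself can be visited.
      stack.pop_back();
      fn(*node);
      continue;
    }
    stack.back().second = true;
    // Push rhs first so that lhs is processed before rhs.
    if (node->rhs)
      stack.emplace_back(node->rhs.get(), false);
    if (node->lhs)
      stack.emplace_back(node->lhs.get(), false);
  }
}
```

The call depth stays constant regardless of how deep the tree is; the cost is a heap-allocated stack that grows with the depth instead.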
OK. I will switch to a `walk` approach and then refactor AffineExprVisitor in another patch, using a stack to avoid recursive calls.
Here is why I don't use walk:
This is the walk implementation. At this point, lhs and rhs have already been traversed, and I don't know where to push onto the stack.
@joker-eph The walk here is the traversal method of AffineExpr. When you mention mlir::walk, do you mean the one for mlir::Operation?
Fixes #107508